{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Outline\n",
    "\n",
    "* Today we will show how to train a ConvNet using PyTorch\n",
    "* We will also illustrate how the ConvNet makes use of specific assumptions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# To perform well, we need to incorporate some prior knowledge about the problem\n",
    "\n",
    "* Assumptions help us when they are true\n",
    "* They hurt us when they are not\n",
    "* We want to make just the right number of assumptions, and no more\n",
    "\n",
    "## In Deep Learning\n",
    "\n",
    "* Many layers: compositionality\n",
    "* Convolutions: locality + stationarity of images\n",
    "* Pooling: invariance of object class to translations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from res.plot_lib import plot_data, plot_model, set_default"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_default()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy\n",
    "\n",
    "# count the total number of parameters in a model\n",
    "def get_n_params(model):\n",
    "    # numel() gives the number of elements in each parameter tensor\n",
    "    return sum(p.numel() for p in model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the Dataset (MNIST)\n",
    "\n",
    "We use torchvision's `datasets.MNIST` and PyTorch's `DataLoader` for this. They download the data, normalize it, shuffle it, and arrange it in batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_size  = 28*28   # images are 28x28 pixels\n",
    "output_size = 10      # there are 10 classes\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=True, download=True,\n",
    "                   transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=64, shuffle=True)\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    datasets.MNIST('../data', train=False, transform=transforms.Compose([\n",
    "                       transforms.ToTensor(),\n",
    "                       transforms.Normalize((0.1307,), (0.3081,))\n",
    "                   ])),\n",
    "    batch_size=1000, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show some images\n",
    "plt.figure(figsize=(16, 6))\n",
    "for i in range(10):\n",
    "    plt.subplot(2, 5, i + 1)\n",
    "    image, _ = train_loader.dataset[i]\n",
    "    plt.imshow(image.squeeze().numpy())\n",
    "    plt.axis('off');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create the model classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FC2Layer(nn.Module):\n",
    "    def __init__(self, input_size, n_hidden, output_size):\n",
    "        super(FC2Layer, self).__init__()\n",
    "        self.input_size = input_size\n",
    "        self.network = nn.Sequential(\n",
    "            nn.Linear(input_size, n_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(n_hidden, n_hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(n_hidden, output_size),\n",
    "            nn.LogSoftmax(dim=1)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, self.input_size)\n",
    "        return self.network(x)\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self, input_size, n_feature, output_size):\n",
    "        super(CNN, self).__init__()\n",
    "        self.n_feature = n_feature\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=5)\n",
    "        self.fc1 = nn.Linear(n_feature*4*4, 50)\n",
    "        # use output_size rather than a hard-coded 10 classes\n",
    "        self.fc2 = nn.Linear(50, output_size)\n",
    "        \n",
    "    def forward(self, x, verbose=False):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, kernel_size=2)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, kernel_size=2)\n",
    "        x = x.view(-1, self.n_feature*4*4)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.fc2(x)\n",
    "        x = F.log_softmax(x, dim=1)\n",
    "        return x"
   ]
  },
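  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `n_feature*4*4` input size of `fc1` comes from the layer arithmetic: a 5x5 convolution without padding shrinks each side by 4, and each 2x2 max-pool halves it, so 28 becomes 24, 12, 8, and finally 4. The cell below is a minimal sketch that checks this on a dummy batch. It rebuilds stand-in layers with the same settings as `CNN`; the names `conv1`, `conv2`, and `dummy` are local to this check, and `n_feature = 6` is an assumed value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "n_feature = 6  # assumed value, chosen only for this shape check\n",
    "conv1 = nn.Conv2d(in_channels=1, out_channels=n_feature, kernel_size=5)\n",
    "conv2 = nn.Conv2d(n_feature, n_feature, kernel_size=5)\n",
    "\n",
    "dummy = torch.zeros(1, 1, 28, 28)  # one fake MNIST-sized image\n",
    "x = F.max_pool2d(F.relu(conv1(dummy)), kernel_size=2)\n",
    "print(x.shape)  # spatial size: 28 -> 24 (conv) -> 12 (pool)\n",
    "x = F.max_pool2d(F.relu(conv2(x)), kernel_size=2)\n",
    "print(x.shape)  # spatial size: 12 -> 8 (conv) -> 4 (pool)\n",
    "\n",
    "# the flattened feature vector matches the input size of fc1\n",
    "assert x.flatten(1).shape[1] == n_feature * 4 * 4"
   ]
  },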
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running on a GPU: device string\n",
    "\n",
    "Switching between CPU and GPU in PyTorch is controlled via a device string. The `device` defined earlier selects the GPU when one is available and falls back to the CPU otherwise. The training and test functions below move each batch to that device:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_list = []\n",
    "\n",
    "def train(epoch, model, perm=torch.arange(0, 784).long()):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        # send to device\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        \n",
    "        # permute pixels\n",
    "        data = data.view(-1, 28*28)\n",
    "        data = data[:, perm]\n",
    "        data = data.view(-1, 1, 28, 28)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % 100 == 0:\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * len(data), len(train_loader.dataset),\n",
    "                100. * batch_idx / len(train_loader), loss.item()))\n",
    "            \n",
    "def test(model, perm=torch.arange(0, 784).long()):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    # no gradients are needed for evaluation\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            # send to device\n",
    "            data, target = data.to(device), target.to(device)\n",
    "\n",
    "            # permute pixels\n",
    "            data = data.view(-1, 28*28)\n",
    "            data = data[:, perm]\n",
    "            data = data.view(-1, 1, 28, 28)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss\n",
    "            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    accuracy = 100. * correct / len(test_loader.dataset)\n",
    "    accuracy_list.append(accuracy)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(test_loader.dataset),\n",
    "        accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a small fully-connected network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_hidden = 8 # number of hidden units\n",
    "\n",
    "model_fnn = FC2Layer(input_size, n_hidden, output_size)\n",
    "model_fnn.to(device)\n",
    "optimizer = optim.SGD(model_fnn.parameters(), lr=0.01, momentum=0.5)\n",
    "print('Number of parameters: {}'.format(get_n_params(model_fnn)))\n",
    "\n",
    "for epoch in range(0, 1):\n",
    "    train(epoch, model_fnn)\n",
    "    test(model_fnn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train a ConvNet with roughly the same number of parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings \n",
    "n_features = 6 # number of feature maps\n",
    "\n",
    "model_cnn = CNN(input_size, n_features, output_size)\n",
    "model_cnn.to(device)\n",
    "optimizer = optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)\n",
    "print('Number of parameters: {}'.format(get_n_params(model_cnn)))\n",
    "\n",
    "for epoch in range(0, 1):\n",
    "    train(epoch, model_cnn)\n",
    "    test(model_cnn)"
   ]
  },
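  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the counts printed by `get_n_params`, the totals can be worked out by hand. The sketch below assumes `n_hidden = 8` and `n_features = 6` as in the cells above; `linear_params` and `conv_params` are hypothetical helpers written for this check, not part of PyTorch. Under these assumptions the totals come to 6442 for the fully-connected network and 6422 for the ConvNet, so the two models are close in size without being exactly equal."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_params(n_in, n_out):\n",
    "    # weight matrix plus one bias per output unit\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "def conv_params(c_in, c_out, k):\n",
    "    # one k x k kernel per (input, output) channel pair, plus one bias per output channel\n",
    "    return c_in * c_out * k * k + c_out\n",
    "\n",
    "n_hidden, n_features = 8, 6  # assumed to match the training cells above\n",
    "\n",
    "fc_total = (linear_params(28 * 28, n_hidden)\n",
    "            + linear_params(n_hidden, n_hidden)\n",
    "            + linear_params(n_hidden, 10))\n",
    "\n",
    "cnn_total = (conv_params(1, n_features, 5)\n",
    "             + conv_params(n_features, n_features, 5)\n",
    "             + linear_params(n_features * 4 * 4, 50)\n",
    "             + linear_params(50, 10))\n",
    "\n",
    "print('Fully-connected:', fc_total)\n",
    "print('ConvNet:', cnn_total)"
   ]
  },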
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The ConvNet performs better with roughly the same number of parameters, thanks to its use of prior knowledge about images\n",
    "\n",
    "* Use of convolution: locality and stationarity in images\n",
    "* Pooling: builds in some translation invariance\n",
    "\n",
    "# What happens if the assumptions are no longer true?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perm = torch.randperm(784)\n",
    "plt.figure(figsize=(16, 12))\n",
    "for i in range(10):\n",
    "    image, _ = train_loader.dataset[i]\n",
    "    # permute pixels\n",
    "    image_perm = image.view(-1, 28*28).clone()\n",
    "    image_perm = image_perm[:, perm]\n",
    "    image_perm = image_perm.view(-1, 1, 28, 28)\n",
    "    plt.subplot(4, 5, i + 1)\n",
    "    plt.imshow(image.squeeze().numpy())\n",
    "    plt.axis('off')\n",
    "    plt.subplot(4, 5, i + 11)\n",
    "    plt.imshow(image_perm.squeeze().numpy())\n",
    "    plt.axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ConvNet with permuted pixels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings \n",
    "n_features = 6 # number of feature maps\n",
    "\n",
    "model_cnn = CNN(input_size, n_features, output_size)\n",
    "model_cnn.to(device)\n",
    "optimizer = optim.SGD(model_cnn.parameters(), lr=0.01, momentum=0.5)\n",
    "print('Number of parameters: {}'.format(get_n_params(model_cnn)))\n",
    "\n",
    "for epoch in range(0, 1):\n",
    "    train(epoch, model_cnn, perm)\n",
    "    test(model_cnn, perm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fully-connected network with permuted pixels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_hidden = 8    # number of hidden units\n",
    "\n",
    "model_fnn = FC2Layer(input_size, n_hidden, output_size)\n",
    "model_fnn.to(device)\n",
    "optimizer = optim.SGD(model_fnn.parameters(), lr=0.01, momentum=0.5)\n",
    "print('Number of parameters: {}'.format(get_n_params(model_fnn)))\n",
    "\n",
    "for epoch in range(0, 1):\n",
    "    train(epoch, model_fnn, perm)\n",
    "    test(model_fnn, perm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The ConvNet's performance drops when we permute the pixels, but the Fully-Connected Network's performance stays the same\n",
    "\n",
    "* ConvNet makes the assumption that pixels lie on a grid and are stationary/local\n",
    "* It loses performance when this assumption is wrong\n",
    "* The fully-connected network does not make this assumption\n",
    "* It does less well when it is true, since it doesn't take advantage of this prior knowledge\n",
    "* But it doesn't suffer when the assumption is wrong"
   ]
  },
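  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see why the fully-connected network is unaffected: permuting the pixels is equivalent to permuting the columns of the first weight matrix, so for every set of weights on the original images there is a matching set on the scrambled images with identical outputs. The following is a minimal sketch of that argument using a single stand-in `nn.Linear` layer (the names `layer`, `layer_perm`, and `x` are illustrative and not part of the notebook's models)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "perm = torch.randperm(784)   # a fixed pixel permutation\n",
    "x = torch.randn(4, 784)      # four fake flattened images\n",
    "\n",
    "layer = nn.Linear(784, 8)    # stand-in for the first layer of FC2Layer\n",
    "layer_perm = nn.Linear(784, 8)\n",
    "with torch.no_grad():\n",
    "    # reorder the input weights the same way the pixels are reordered\n",
    "    layer_perm.weight.copy_(layer.weight[:, perm])\n",
    "    layer_perm.bias.copy_(layer.bias)\n",
    "    # compare outputs on original pixels vs. permuted pixels\n",
    "    same = torch.allclose(layer(x), layer_perm(x[:, perm]), atol=1e-5)\n",
    "\n",
    "print('Outputs match up to rounding:', same)"
   ]
  },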
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.bar(('NN image', 'CNN image',\n",
    "         'CNN scrambled', 'NN scrambled'),\n",
    "        accuracy_list, width=0.4)\n",
    "plt.ylim((min(accuracy_list)-5, 96))\n",
    "plt.ylabel('Accuracy [%]')\n",
    "# enlarge the x tick labels\n",
    "plt.gca().tick_params(axis='x', labelsize=20)\n",
    "plt.title('Performance comparison');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checking Model Parameters"
   ]
  },
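  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `dir(...)` calls in this section list every attribute of a module, which makes for a long printout. To see only the learnable tensors, `named_parameters()` (a standard `nn.Module` method) is usually easier to read. The cell below is a small sketch on a hypothetical stand-in model called `toy`; the same loop should apply to `model_cnn` or `model_fnn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# hypothetical stand-in model, used only to keep this cell self-contained\n",
    "toy = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))\n",
    "\n",
    "# print each parameter's name, shape, and number of elements\n",
    "for name, p in toy.named_parameters():\n",
    "    print(name, tuple(p.shape), p.numel())\n",
    "\n",
    "print('total:', sum(p.numel() for p in toy.parameters()))"
   ]
  },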
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dir(model_cnn))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(dir(model_fnn))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
